import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from math import pi
import gc
# Seed torch's global RNG once, at import time, so the random kernel
# initialisation (torch.rand) is reproducible between runs. Note that this is a
# process-wide side effect for anything that imports this module.
_GLOBAL_SEED = 2020
torch.manual_seed(_GLOBAL_SEED)

class CtsConv(nn.Module):
    """Continuous convolution over point sets with a learnable 3-D grid kernel.

    For every center point, the field points are expressed relative to the
    center (divided by ``radius``), warped by ``Lambda`` and used as
    ``F.grid_sample`` coordinates into the kernel grid. Each sampled
    ``[c_out, c_in]`` matrix is applied to that field point's features, weighted
    by a smooth radial window (``GetAttention``) and the field mask, and summed.

    @in_channels: number of input feature channels (c_in).
    @out_channels: number of output feature channels (c_out).
    @kernel_sizes: three ints, the resolution of the kernel grid.
    @radius: neighbourhood radius used to normalise relative positions.
    @normalize_attention: if True, ``forward`` divides its output by
        ``1 + sum(attention)`` over the field points.
    @layer_name: optional extra name under which the kernel Parameter is also
        registered.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes, radius, normalize_attention=False,
                 layer_name=None):
        super(CtsConv, self).__init__()
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.kernel_sizes = kernel_sizes
        # kernel: [c_out, c_in, *kernel_sizes]
        self.kernel = torch.nn.parameter.Parameter(
            self.init_kernel(in_channels, out_channels, kernel_sizes), requires_grad=True
        )
        if layer_name is not None:
            # The same Parameter object becomes reachable under both names.
            self.register_parameter(layer_name, self.kernel)
        self.radius = radius
        self.normalize_attention = normalize_attention

    def init_kernel(self, in_channels, out_channels, kernel_sizes):
        """Return a kernel drawn uniformly from [-0.5, 0.5) / sqrt(in_channels).

        return: [out_channels, in_channels, *kernel_sizes]
        """
        kernel = torch.rand(out_channels, in_channels, *kernel_sizes)
        kernel -= 0.5
        k = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float))
        kernel *= 1 * k
        return kernel

    def Lambda(self, vec):
        """Warp relative positions from the disc onto the square grid.

        Only the first two coordinates are transformed (see ``map_polar_sqr``).
        The third coordinate is passed through unchanged.

        @vec: [..., 3]
        return: [..., 3]
        """
        x_out, y_out = self.map_polar_sqr(vec[..., 0], vec[..., 1])
        return torch.stack([x_out, y_out, vec[..., 2]], dim=-1)

    def map_polar_sqr(self, x, y, epsilon=1e-9):
        """Map each point of radius r onto the boundary of the square [-r, r]^2.

        The point's radius ``r = sqrt(x^2 + y^2 + epsilon)`` becomes its max-abs
        coordinate. Its angle within the octant is spread linearly along that
        square edge. The exact origin maps to the origin.

        @x, @y: tensors of identical shape.
        @epsilon: added under the square root (presumably to keep the sqrt
            gradient finite at the origin).
        return: (x_out, y_out), with the same shape, dtype and device as ``x``.
        """
        r = torch.sqrt(x ** 2 + y ** 2 + epsilon)

        at_origin = (x == 0.) & (y == 0.)
        # |y| <= |x| away from the origin implies x != 0, so y / x is safe.
        x_major = (torch.abs(y) <= torch.abs(x)) & (~at_origin)
        # The remaining points have |y| > |x|, hence y != 0, so x / y is safe.
        y_major = ~(at_origin | x_major)

        # Follow the input's dtype and device (rather than the default dtype on
        # the kernel's device). Origin entries are left at 0.
        x_out = torch.zeros_like(x)
        y_out = torch.zeros_like(x)

        x_out[x_major] = torch.sign(x[x_major]) * r[x_major]
        y_out[x_major] = 4 / pi * torch.sign(x[x_major]) * r[x_major] * torch.atan(y[x_major] / x[x_major])

        x_out[y_major] = 4 / pi * torch.sign(y[y_major]) * r[y_major] * torch.atan(x[y_major] / y[y_major])
        y_out[y_major] = torch.sign(y[y_major]) * r[y_major]

        return x_out, y_out

    def InterpolateKernelUnit(self, kernel, pos):
        """Sample the kernel at ``num`` positions per batch element.

        @kernel: [c_out, c_in=feat_dim, k0, k1, k2]
        @pos: [batch, num, 3], grid_sample coordinates (points outside [-1, 1]
              sample zeros).
        return: [batch, num, c_out, c_in]
        """
        # A [batch, num, 3] query is InterpolateKernel's [batch, num_m, num_n, 3]
        # query with num_n == 1, so reuse it instead of duplicating the sampling.
        return self.InterpolateKernel(kernel, pos.unsqueeze(2)).squeeze(2)

    def GetAttention(self, relative_field):
        """Smooth radial window ``max(0, 1 - |p|^2) ** 3``.

        @relative_field: [..., pos_dim], positions already divided by the radius.
        return: [..., 1]; 1 at the center and 0 at or beyond unit distance.
        """
        sq_dist = torch.sum(relative_field ** 2, dim=-1)
        return torch.relu((1 - sq_dist) ** 3).unsqueeze(-1)

    def ContinuousConvUnit(
        self, kernel, field, center, field_feat,
        field_mask, ctr_feat=None, normalize_attention=False
    ):
        """Continuous convolution around a single center per batch element.

        @kernel: [c_out, c_in=feat_dim, k0, k1, k2]
        @field: [batch, num, pos_dim=3]
        @center: [batch, 1, pos_dim=3]
        @field_feat: [batch, num, c_in=feat_dim]
        @field_mask: [batch, num, 1]
        @ctr_feat: unused; kept for interface compatibility.
        @normalize_attention: if True, divide by the summed attention. This path
            reads the argument rather than ``self.normalize_attention``, and it
            has no ``+ 1`` term, unlike ``ContinuousConv``.
        return: [batch, c_out]
        """
        relative_field = (field - center) / self.radius

        attention = self.GetAttention(relative_field) * field_mask
        # attention: [batch, num, 1]

        if normalize_attention:
            psi = torch.sum(attention, dim=1)
            # psi == 0 happens e.g. when no field point lies inside the radius.
            # The numerator is then 0 as well, so divide by 1 instead of
            # producing 0 / 0 = NaN.
            psi = psi.masked_fill(psi == 0, 1.0)
        else:
            psi = 1

        scaled_field = self.Lambda(relative_field)

        kernel_on_field = self.InterpolateKernelUnit(kernel, scaled_field)
        # kernel_on_field: [batch, num, c_out, c_in]

        out = torch.einsum('bnoi,bni->bo', kernel_on_field, field_feat * attention)
        # out: [batch, c_out]

        return out / psi

    def InterpolateKernel(self, kernel, pos):
        """Sample the kernel on a [num_m, num_n] set of positions per batch element.

        @kernel: [c_out, c_in=feat_dim, k0, k1, k2], treated as
                 C = c_out * c_in channels of one 3-D volume.
        @pos: [batch, num_m, num_n, 3], grid_sample coordinates (points outside
              [-1, 1] sample zeros because padding_mode='zeros').
        return: [batch, num_m, num_n, c_out, c_in]
        """
        kernels = kernel.reshape(-1, *kernel.shape[2:]).unsqueeze(0)
        kernels = kernels.expand((pos.shape[0], *kernels.shape[1:]))
        # kernels: [batch, C, k0, k1, k2] (expand returns a view, no copy)

        # grid: [batch, num_m, num_n, 1, 3] -> out: [batch, C, num_m, num_n, 1]
        out = F.grid_sample(kernels, pos.unsqueeze(-2), padding_mode='zeros',
                            mode='bilinear', align_corners=False)
        out = out.squeeze(-1).permute(0, 2, 3, 1)
        # out: [batch, num_m, num_n, C]
        return out.reshape(*pos.shape[:-1], *kernel.shape[0:2])

    def ContinuousConv(
        self, kernel, field, center, field_feat,
        field_mask, ctr_feat=None
    ):
        """Continuous convolution for every (center, field point) pair.

        @kernel: [c_out, c_in=feat_dim, k0, k1, k2]
        @field: [batch, num_n, pos_dim=3]
        @center: [batch, num_m, pos_dim=3]
        @field_feat: [batch, num_n, c_in=feat_dim]
        @field_mask: [batch, num_n]. A trailing singleton dim (as used by
                     ContinuousConvUnit) does not broadcast correctly here.
        @ctr_feat: unused here; accepted so subclasses share the signature.
        return: [batch, num_m, c_out]
        """
        relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius
        # relative_field: [batch, num_m, num_n, pos_dim]

        attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1).unsqueeze(-1)
        # attention: [batch, num_m, num_n, 1]

        # The + 1 keeps the denominator >= 1 (for a non-negative mask), even for
        # centers with no field point inside the radius.
        psi = torch.sum(attention, dim=2) + 1 if self.normalize_attention else 1

        scaled_field = self.Lambda(relative_field)
        # scaled_field: [batch, num_m, num_n, pos_dim]
        # Drop our reference to the [batch, num_m, num_n, 3] tensor early.
        del relative_field

        kernel_on_field = self.InterpolateKernel(kernel, scaled_field)
        # kernel_on_field: [batch, num_m, num_n, c_out, c_in]

        weighted_feat = field_feat.unsqueeze(1) * attention
        # weighted_feat: [batch, num_m, num_n, c_in]
        out = torch.einsum('bmnoi,bmni->bmo', kernel_on_field, weighted_feat)
        # out: [batch, num_m, c_out]

        return out / psi

    def forward(
        self, field, center, field_feat,
        field_mask, ctr_feat=None
    ):
        """Apply ``ContinuousConv`` with this layer's kernel; see it for shapes."""
        out = self.ContinuousConv(
            self.kernel, field, center, field_feat, field_mask, ctr_feat
        )
        return out

    def extra_repr(self):
        return 'input_channels={}, output_channels={}, kernel_size={}'.format(
            self.in_channels, self.out_channels, self.kernel_sizes
        )


class RelCtsConv(CtsConv):
    """CtsConv variant that convolves feature differences.

    Same as ``CtsConv`` except that each field point contributes
    ``field_feat - ctr_feat`` (its features relative to the center's) instead of
    ``field_feat``. For that reason ``ctr_feat`` is required.
    Construction arguments and ``extra_repr`` are inherited unchanged.
    """

    def ContinuousConv(
        self, kernel, field, center, field_feat,
        field_mask, ctr_feat=None
    ):
        """Continuous convolution of relative features for every (center, field point) pair.

        @kernel: [c_out, c_in=feat_dim, k0, k1, k2]
        @field: [batch, num_n, pos_dim=3]
        @center: [batch, num_m, pos_dim=3]
        @field_feat: [batch, num_n, c_in=feat_dim]
        @field_mask: [batch, num_n]
        @ctr_feat: [batch, num_m, c_in] (or [batch, 1, c_in], which broadcasts
                   over the centers); required.
        return: [batch, num_m, c_out]
        raises ValueError: if ctr_feat is None.
        """
        if ctr_feat is None:
            # Previously this failed later with an opaque AttributeError on None.
            raise ValueError("RelCtsConv requires ctr_feat (the centers' features).")

        relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius
        # relative_field: [batch, num_m, num_n, pos_dim]

        attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1).unsqueeze(-1)
        # attention: [batch, num_m, num_n, 1]

        psi = torch.sum(attention, dim=2) + 1 if self.normalize_attention else 1

        scaled_field = self.Lambda(relative_field)
        # scaled_field: [batch, num_m, num_n, pos_dim]
        # Drop our reference to the [batch, num_m, num_n, 3] tensor early.
        del relative_field

        kernel_on_field = self.InterpolateKernel(kernel, scaled_field)
        # kernel_on_field: [batch, num_m, num_n, c_out, c_in]

        rel_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2)
        # rel_feat: [batch, num_m, num_n, c_in]

        out = torch.einsum('bmnoi,bmni->bmo', kernel_on_field, rel_feat * attention)
        # out: [batch, num_m, c_out]

        return out / psi

    def forward(
        self, field, center, field_feat,
        field_mask, ctr_feat
    ):
        """Apply ``ContinuousConv`` with this layer's kernel; ``ctr_feat`` is required."""
        out = self.ContinuousConv(
            self.kernel, field, center, field_feat, field_mask, ctr_feat
        )
        return out

  
